from collections.abc import Generator

import pytest

from core.app.entities.app_invoke_entities import InvokeFrom
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolProviderType
from core.tools.errors import ToolInvokeError
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.graph_engine import Graph, GraphInitParams, GraphRuntimeState
from core.workflow.nodes.answer import AnswerStreamGenerateRoute
from core.workflow.nodes.end import EndStreamParam
from core.workflow.nodes.enums import ErrorStrategy
from core.workflow.nodes.event import RunCompletedEvent
from core.workflow.nodes.tool import ToolNode
from core.workflow.nodes.tool.entities import ToolNodeData
from models import UserFrom, WorkflowNodeExecutionStatus, WorkflowType


def _create_tool_node():
    """Build a standalone ``ToolNode`` for the error-handling test.

    The node wraps a workflow-provider tool configured with the FAIL_BRANCH
    error strategy. It runs against an empty variable pool and a minimal graph
    whose root is the node itself (id ``"1"``). Tenant, app, workflow and user
    ids are all dummy ``"1"`` values.
    """
    node_data = ToolNodeData(
        title="Test Tool",
        tool_parameters={},
        provider_id="test_tool",
        provider_type=ToolProviderType.WORKFLOW,
        provider_name="test tool",
        tool_name="test tool",
        tool_label="test tool",
        tool_configurations={},
        plugin_unique_identifier=None,
        desc="Exception handling test tool",
        error_strategy=ErrorStrategy.FAIL_BRANCH,
        version="1",
    )
    pool = VariablePool(
        system_variables={},
        user_inputs={},
    )

    # Build the pieces in the same order the original single nested call
    # evaluated them: config, init params, graph, runtime state.
    node_config = {
        "id": "1",
        "data": node_data.model_dump(),
    }
    init_params = GraphInitParams(
        tenant_id="1",
        app_id="1",
        workflow_type=WorkflowType.WORKFLOW,
        workflow_id="1",
        graph_config={},
        user_id="1",
        user_from=UserFrom.ACCOUNT,
        invoke_from=InvokeFrom.SERVICE_API,
        call_depth=0,
    )
    single_node_graph = Graph(
        root_node_id="1",
        answer_stream_generate_routes=AnswerStreamGenerateRoute(
            answer_dependencies={},
            answer_generate_route={},
        ),
        end_stream_param=EndStreamParam(
            end_dependencies={},
            end_stream_variable_selector_mapping={},
        ),
    )
    runtime_state = GraphRuntimeState(
        variable_pool=pool,
        start_at=0,
    )

    return ToolNode(
        id="1",
        config=node_config,
        graph_init_params=init_params,
        graph=single_node_graph,
        graph_runtime_state=runtime_state,
    )


class MockToolRuntime:
    """Dummy tool runtime returned by the patched
    ``ToolManager.get_workflow_tool_runtime`` in the test below.
    """

    def get_merged_runtime_parameters(self):
        """Report no runtime parameters (always returns ``None``)."""
        return None


def mock_message_stream() -> Generator[ToolInvokeMessage, None, None]:
    """Fake message stream that fails before producing any message.

    Calling this returns a generator without raising. The first ``next()`` on
    that generator raises ``ToolInvokeError("oops")``.
    """
    # Delegating to an empty iterable makes this a generator function, so the
    # error is deferred until the stream is consumed rather than raised at call time.
    yield from ()
    raise ToolInvokeError("oops")


def test_tool_node_on_tool_invoke_error(monkeypatch: pytest.MonkeyPatch):
    """A ToolInvokeError raised while ToolNode transforms the messages from
    ToolEngine.generic_invoke must yield exactly one FAILED run result.
    """
    node = _create_tool_node()

    # Stub tool lookup and invocation so no database setup is required.
    stubs = (
        (
            "core.tools.tool_manager.ToolManager.get_workflow_tool_runtime",
            lambda *args, **kwargs: MockToolRuntime(),
        ),
        (
            "core.tools.tool_engine.ToolEngine.generic_invoke",
            lambda *args, **kwargs: mock_message_stream(),
        ),
    )
    for target, replacement in stubs:
        monkeypatch.setattr(target, replacement)

    events = list(node._run())
    assert len(events) == 1

    completed = events[0]
    assert isinstance(completed, RunCompletedEvent)

    run_result = completed.run_result
    assert isinstance(run_result, NodeRunResult)
    assert run_result.status == WorkflowNodeExecutionStatus.FAILED
    assert "oops" in run_result.error
    assert "Failed to transform tool message:" in run_result.error
    assert run_result.error_type == "ToolInvokeError"
